package face3wap;

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;

/**
 *  faceswap: swap face B onto face A.
 *  1. Face A pipeline:  EA -> DA -> DISA
 *  2. Face B pipeline:  EB -> DB -> DISB
 *  3. The encoder parameters of EA and EB are synced after every training pass,
 *     so A and B share one common encoder.
 *  4. Swapping B onto A: face A -> EB -> DB -> B's features + A
 *  5. Swapped images are saved out to a folder.
 *
 *  Known issues
 *  1. The score stops changing after ~600 iterations. Consider WGAN-style optimization.
 *  2. CPU training is far too slow for proper validation.
 */

public class Img3GanModel {
    // Base learning rate fed into the Adam UPDATER below.
    // NOTE(review): UPDATER is never referenced in this file (the graphs use Sgd(lr));
    // confirm before removing either constant.
    private static final double LEARNING_RATE = 0.01;
    // Mini-batch size; hard-coded into every ReshapeVertex/ShuffleVertex shape below.
    public static final int batch = 8;
    // NOTE(review): unused -- the gradient-normalization builder calls are commented out.
    private static final double GRADIENT_THRESHOLD = 100.0;
    // Root drive for all data and model paths (Windows layout).
    public static String panfu = "E:\\";
    private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build();
    static String modelA = panfu+"face\\model\\imgA.zip"; // checkpoint file for net A
    static String modelB = panfu+"face\\model\\imgB.zip"; // checkpoint file for net B
    private static final int seed = 12345; // RNG seed for reproducible weight init

    static int height = 64; // input image height
    static int width = 64; // input image width
    static int channels = 3; // input image channel count (RGB)
    static int[] inputShape = new int[] {channels, width, height}; // NOTE(review): unused in this file
    static double lr = 0.01; // SGD learning rate toggled per-layer by trainD/trainG
    private static JFrame frame; // lazily created preview window (shared by both visualizers)
    private static JPanel panel; // grid of rendered sample images

    /**
     * Builds a stand-alone encoder graph: a 64x64x3 image is downsampled by a
     * strided conv stack to 4x4x1024, squeezed through a dense bottleneck, then
     * upsampled once via a pixel-shuffle-style ShuffleVertex to an 8x8x512
     * feature map, terminated by a placeholder OutputLayer.
     *
     * NOTE(review): this method is never called in this file and duplicates the
     * encoder half of {@link #gan()} -- confirm it is still needed before
     * maintaining both copies.
     */
    public static ComputationGraphConfiguration encoder(){
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater( new Sgd(lr))
                /* .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)*/
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .graphBuilder()
                .addInputs("input1")
                .setInputTypes(InputType.convolutional(height, width, channels))
                // Conv output-size rules: Valid -> (W - F + 1)/S, Same -> W/S
                // (N: output size, W: input size, F: kernel size, P: padding, S: stride)
                // g1 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g1", new ConvolutionLayer.Builder().kernelSize(5,5).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "input1")
                // conv block
                // g2 = 64 / 2 -> 32 x 32 x 128   (out = in / stride)
                .addLayer("g2", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128).build(), "g1")
                .addLayer("g2_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g2")
                // g3 = 32 / 2 -> 16 x 16 x 256
                .addLayer("g3", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256).build(), "g2_relu")
                .addLayer("g3_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g3")
                // g4 = 16 / 2 -> 8 x 8 x 512
                .addLayer("g4", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512).build(), "g3_relu")
                .addLayer("g4_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g4")
                // g5 = 8 / 2 -> 4 x 4 x 1024
                .addLayer("g5", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2,2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1024).build(), "g4_relu")
                .addLayer("g5_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g5")
                // Flatten to 1-D for the fully-connected bottleneck (batch size is baked in).

                .addVertex("Flatten",new ReshapeVertex(batch,  4 * 4  * 1024),"g5_relu")
                //.addVertex("Flatten",new PreprocessorVertex(new KerasFlattenRnnPreprocessor(1024,1)),"g5_relu")
                .addLayer("g6", new DenseLayer.Builder().nOut(1024).build(), "Flatten")
                .addLayer("g7", new DenseLayer.Builder().nOut(4 * 4 * 1024) .build(), "g6")
                // Back to 4-D (NCHW) for the convolutional stages.
                .addVertex("reshape",new ReshapeVertex(batch,1024,4, 4),"g7")
                // g8 = 4 / 1 -> 4 x 4 x (512*4), then shuffled to 8 x 8 x 512
                .addLayer("g8", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512 * 4).build(), "reshape")
                .addLayer("g8_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),"g8")

                // Pixel-shuffle-style rearrangement to 8 x 8 x 512.
                .addVertex("Shuffle1",new ShuffleVertex(batch,512,8, 8),"g8_relu")
                // Encoder complete.
                .addLayer("out", new OutputLayer.Builder().nIn(512).nOut(512).build(), "Shuffle1")
                .setOutputs("out")
                .build();
    }


    /**
     * Builds the full GAN graph used for one identity:
     * <ul>
     *   <li>Encoder (g1..g8 + Shuffle1): input1 64x64x3 -> 8x8x512 feature map.</li>
     *   <li>Decoder (g9..g17): three shuffle-upsample stages back to 64x64, two
     *       residual blocks, then a 1-channel SIGMOID mask (g16) and a 3-channel
     *       TANH color output (g17).</li>
     *   <li>Discriminator (d1..d4 + out): a StackVertex stacks the "real" pair
     *       (mask + input2) on top of the "fake" pair (mask + generated color),
     *       so the batch is doubled; XENT output scores each row.</li>
     * </ul>
     * trainD/trainG freeze one half at a time by zeroing per-layer learning rates.
     */
    public static ComputationGraphConfiguration gan(){
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .updater( new Sgd(lr))
                /* .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                 .gradientNormalizationThreshold(GRADIENT_THRESHOLD)*/
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.IDENTITY)
                .graphBuilder()
                .addInputs("input1","input2")
                .setInputTypes(InputType.convolutional(height, width, channels),InputType.convolutional(height, width, channels))
                // Conv output-size rules: Valid -> (W - F + 1)/S, Same -> W/S
                // (N: output size, W: input size, F: kernel size, P: padding, S: stride)
                // g1 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g1", new ConvolutionLayer.Builder().kernelSize(5,5).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "input1")
                // conv block
                // g2 = 64 / 2 -> 32 x 32 x 128   (out = in / stride)
                .addLayer("g2", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128).build(), "g1")
                .addLayer("g2_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g2")
                // g3 = 32 / 2 -> 16 x 16 x 256
                .addLayer("g3", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256).build(), "g2_relu")
                .addLayer("g3_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g3")
                // g4 = 16 / 2 -> 8 x 8 x 512
                .addLayer("g4", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2, 2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512).build(), "g3_relu")
                .addLayer("g4_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g4")
                // g5 = 8 / 2 -> 4 x 4 x 1024
                .addLayer("g5", new ConvolutionLayer.Builder().kernelSize(3,3).stride(2,2).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1024).build(), "g4_relu")
                .addLayer("g5_relu", new ActivationLayer.Builder().activation(Activation.RELU).build(),"g5")
                // Flatten to 1-D for the fully-connected bottleneck (batch size is baked in).

                .addVertex("Flatten",new ReshapeVertex(batch,  4 * 4  * 1024),"g5_relu")
                //.addVertex("Flatten",new PreprocessorVertex(new KerasFlattenRnnPreprocessor(1024,1)),"g5_relu")
                .addLayer("g6", new DenseLayer.Builder().nOut(1024).build(), "Flatten")
                .addLayer("g7", new DenseLayer.Builder().nOut(4 * 4 * 1024) .build(), "g6")
                // Back to 4-D (NCHW) for the convolutional stages.
                .addVertex("reshape",new ReshapeVertex(batch,1024,4, 4),"g7")
                // g8 = 4 / 1 -> 4 x 4 x (512*4), then shuffled to 8 x 8 x 512
                .addLayer("g8", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(512 * 4).build(), "reshape")
                .addLayer("g8_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),"g8")

                // Pixel-shuffle-style rearrangement to 8 x 8 x 512.
                .addVertex("Shuffle1",new ShuffleVertex(batch,512,8, 8),"g8_relu")
                // Encoder complete.

                // Decoder: starts from 8 x 8 x 512, ends at 64 x 64 x 4 (mask + RGB).
                //.addVertex("reshape1",new ReshapeVertex(batch,8, 8, 512),"Shuffle1")
                // g9 = 8 / 1 -> 8 x 8 x (256*4), shuffled to 16 x 16 x 256
                .addLayer("g9", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256 * 4).build(), "Shuffle1")
                .addLayer("g9_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),"g9")
                .addVertex("Shuffle2",new ShuffleVertex(batch,256,16, 16),"g9_relu")
                // g10 = 16 / 1 -> 16 x 16 x (128*4), shuffled to 32 x 32 x 128
                .addLayer("g10", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128 * 4).build(), "Shuffle2")
                .addLayer("g10_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),"g10")
                .addVertex("Shuffle3",new ShuffleVertex(batch,128,32, 32 ),"g10_relu")
                // g11 = 32 / 1 -> 32 x 32 x (64*4), shuffled to 64 x 64 x 64
                .addLayer("g11", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64 * 4).build(), "Shuffle3")
                .addLayer("g11_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.1)).build(),"g11")
                .addVertex("Shuffle4",new ShuffleVertex(batch,64,64, 64 ),"g11_relu")

                // Residual block 1 (skip connection from Shuffle4)
                // g12 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g12", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "Shuffle4")
                .addLayer("g12_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.2)).build(),"g12")
                // g13 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g13", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "g12_relu")
                .addVertex("add1",new ElementWiseVertex(ElementWiseVertex.Op.Add),"g13","Shuffle4")
                .addLayer("g13_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.2))
                        .build(), "add1")

                // Residual block 2 (skip connection from g13_relu)
                // g14 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g14", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "g13_relu")
                .addLayer("g14_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.2)).build(),"g14")
                // g15 = 64 / 1 -> 64 x 64 x 64
                .addLayer("g15", new ConvolutionLayer.Builder().kernelSize(3,3).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).build(), "g14_relu")
                // element-wise sum, still 64 x 64 x 64
                .addVertex("add2",new ElementWiseVertex(ElementWiseVertex.Op.Add),"g15","g13_relu")
                .addLayer("g15_relu", new ActivationLayer.Builder().activation(new ActivationLReLU(0.2))
                        .build(), "add2")
                // g16 = 64 / 1 -> 64 x 64 x 1 : SIGMOID blending mask
                .addLayer("g16", new ConvolutionLayer.Builder().kernelSize(5,5).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1).activation(Activation.SIGMOID).build(), "g15_relu")
                // g17 = 64 / 1 -> 64 x 64 x 3 : TANH color image
                .addLayer("g17", new ConvolutionLayer.Builder().kernelSize(5,5).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(3).activation(Activation.TANH).build(), "g15_relu")

                // merge2 = fake pair (mask + generated), merge3 = real pair (mask + input2);
                // StackVertex doubles the batch: real rows first, fake rows second.
                .addVertex("merge2",new MergeVertex(),"g16","g17")
                .addVertex("merge3",new MergeVertex(),"g16","input2")
                .addVertex("stack", new StackVertex(), "merge3", "merge2")
                // Discriminator input: (2*batch) x 4 x 64 x 64
                .addLayer("d1", new ConvolutionLayer.Builder(new int[]{4,4},new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(64).activation(new ActivationLReLU(0.2)).build(), "stack")
                // (2*batch) x 128 x 32 x 32 -> downsampled again below
                .addLayer("d2", new ConvolutionLayer.Builder(new int[]{4,4},new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(128).activation(new ActivationLReLU(0.2)).build(), "d1")
                // (2*batch) x 256 x 16 x 16
                .addLayer("d3", new ConvolutionLayer.Builder(new int[]{4,4},new int[]{2,2}).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(256).activation(new ActivationLReLU(0.2)).build(), "d2")
                // (2*batch) x 1 x 8 x 8 -> flattened to 64 for the output layer
                .addLayer("d4", new ConvolutionLayer.Builder(new int[]{4, 4}).stride(1, 1).convolutionMode(ConvolutionMode.Same).hasBias(false)
                        .nOut(1).build(), "d3")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(64).nOut(1)
                        .activation(Activation.SIGMOID).build(), "d4")
                .setOutputs("out")
                .build();

    }


    /**
     * Training driver: alternately trains the A-face and B-face GANs on images
     * from {panfu}face/yzm/A and .../B, syncing the shared encoder weights
     * (layers g1..g5) between the two networks after every generator step, and
     * checkpointing both models each iteration.
     *
     * @throws Exception if model files or image readers cannot be initialized
     */
    public static void main(String... args) throws Exception {
        // Periodic auto-GC keeps ND4J's off-heap memory bounded during the long loop.
        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        // Resume from a saved checkpoint when present, otherwise build a fresh graph.
        ComputationGraph netA;
        if (new File(modelA).exists()) {
            netA = ComputationGraph.load(new File(modelA), true);
        } else {
            netA = new ComputationGraph(gan());
        }
        ComputationGraph netB;
        if (new File(modelB).exists()) {
            netB = ComputationGraph.load(new File(modelB), true);
        } else {
            netB = new ComputationGraph(gan());
        }
        netA.init();
        netB.init();
        System.out.println(netA.summary());
        System.out.println(netB.summary());
        netA.setListeners(new ScoreIterationListener(10));
        netB.setListeners(new ScoreIterationListener(10));

        // Photos for identity A and identity B live in sibling folders.
        String inputDataDir = panfu + "face/yzm";
        File trainDataFileA = new File(inputDataDir + "/A"),
             trainDataFileB = new File(inputDataDir + "/B");

        FileSplit trainSplitA = new FileSplit(trainDataFileA, NativeImageLoader.ALLOWED_FORMATS),
                  trainSplitB = new FileSplit(trainDataFileB, NativeImageLoader.ALLOWED_FORMATS);

        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader trainRRA = new ImageRecordReader(height, width, channels, labelMaker),
                          trainRRB = new ImageRecordReader(height, width, channels, labelMaker);

        trainRRA.initialize(trainSplitA);
        trainRRB.initialize(trainSplitB);

        DataSetIterator trainDataA = new RecordReaderDataSetIterator(trainRRA, batch, 1, 1),
                        trainDataB = new RecordReaderDataSetIterator(trainRRB, batch, 1, 1);

        // Scale pixel values into [0, 1] before feeding the networks.
        DataNormalization scalerA = new ImagePreProcessingScaler(0, 1),
                          scalerB = new ImagePreProcessingScaler(0, 1);
        scalerA.fit(trainDataA);
        scalerB.fit(trainDataB);
        trainDataA.setPreProcessor(scalerA);
        trainDataB.setPreProcessor(scalerB);
        trainDataA.reset();
        trainDataB.reset();

        // Discriminator labels: fake rows = 1, real rows = 0 (row order follows the
        // StackVertex in gan()). The generator is trained against all-ones.
        INDArray labelD = Nd4j.vstack(Nd4j.ones(batch, 1), Nd4j.zeros(batch, 1));
        INDArray labelG = Nd4j.ones(batch * 2, 1);

        for (int i = 1; i <= 100000; i++) {
            // Rewind each iterator when it runs out of images.
            if (!trainDataA.hasNext()) {
                trainDataA.reset();
            }
            if (!trainDataB.hasNext()) {
                trainDataB.reset();
            }
            INDArray trueExpA = trainDataA.next().getFeatures();
            INDArray trueExpB = trainDataB.next().getFeatures();

            // Random noise input for the generator half of each graph.
            INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { batch, channels, height, width });

            // Train A: 10 discriminator steps, then 1 generator step.
            MultiDataSet dataSetDA = new MultiDataSet(new INDArray[] { z, trueExpA }, new INDArray[] { labelD });
            for (int m = 0; m < 10; m++) {
                System.out.println("训练A的第"+i+":"+m);
                trainD(netA, dataSetDA);
            }
            MultiDataSet dataSetGA = new MultiDataSet(new INDArray[] { z, trueExpA }, new INDArray[] { labelG });
            trainG(netA, dataSetGA);
            updateA2B(netA, netB);

            // Train B the same way.
            MultiDataSet dataSetDB = new MultiDataSet(new INDArray[] { z, trueExpB }, new INDArray[] { labelD });
            for (int m = 0; m < 10; m++) {
                System.out.println("训练B的第"+i+":"+m);
                trainD(netB, dataSetDB);
            }
            // BUG FIX: B's generator was trained against A's real images (trueExpA);
            // it must see B's own images, otherwise net B never learns identity B.
            MultiDataSet dataSetGB = new MultiDataSet(new INDArray[] { z, trueExpB }, new INDArray[] { labelG });
            trainG(netB, dataSetGB);
            updateB2A(netB, netA);

            // (i % 1) checkpoints every iteration; raise the modulus to save less often.
            if (i % 1 == 0) {
                netA.save(new File(modelA), true);
                netB.save(new File(modelB), true);
            }
        }
    }

    /** Copies the shared encoder weights (layers g1..g5) from net A into net B. */
    private static void updateA2B(ComputationGraph ganA, ComputationGraph ganB) {
        System.out.println("参数：A->B");
        for (String layer : new String[] {"g1", "g2", "g3", "g4", "g5"}) {
            ganB.getLayer(layer).setParams(ganA.getLayer(layer).params());
        }
    }
    /**
     * Copies the shared encoder weights (layers g1..g5) from net B into net A.
     *
     * BUG FIX: the original body copied from its second argument into its first;
     * with the call site updateB2A(netB, netA) that copied A's encoder into B --
     * the same direction as updateA2B -- so B's encoder training was discarded
     * every iteration and never reached A. Parameters are renamed to make the
     * data flow explicit; the call site is positional and unchanged.
     */
    private static void updateB2A(ComputationGraph ganB, ComputationGraph ganA) {
        System.out.println("参数：B->A");
        for (String layer : new String[] {"g1", "g2", "g3", "g4", "g5"}) {
            ganA.getLayer(layer).setParams(ganB.getLayer(layer).params());
        }
    }

    /**
     * One discriminator step D(x): freezes the generator layers (g1..g17) by
     * zeroing their learning rate, enables the discriminator layers (d1..d4 and
     * "out") at the shared rate, then fits one mini-batch.
     */
    public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
        for (int k = 1; k <= 17; k++) {
            net.setLearningRate("g" + k, 0);
        }
        for (int k = 1; k <= 4; k++) {
            net.setLearningRate("d" + k, lr);
        }
        net.setLearningRate("out", lr);
        net.fit(dataSet);
    }
    /**
     * One generator step G(z): enables the generator layers (g1..g17) at the
     * shared rate, freezes the discriminator layers (d1..d4 and "out") by
     * zeroing their learning rate, then fits one mini-batch.
     */
    public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
        for (int k = 1; k <= 17; k++) {
            net.setLearningRate("g" + k, lr);
        }
        for (int k = 1; k <= 4; k++) {
            net.setLearningRate("d" + k, 0);
        }
        net.setLearningRate("out", 0);
        net.fit(dataSet);
    }




    /**
     * Displays every element of each (batch, C, H, W) tensor in the shared
     * preview frame.
     *
     * BUG FIX: the original inner loop added the *whole* tensor once per batch
     * element, and imageFromINDArray always reads index 0 -- so the first image
     * was rendered batch times. Each element is now sliced out with its batch
     * dimension kept (interval(i, i+1)) so the downstream index-0 reads hit
     * element i.
     */
    private static void visualize(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();
            panel.setLayout(new GridLayout(8 / 4, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        for (INDArray sample : samples) {
            long batc = sample.shape()[0];
            for (long i = 0; i < batc; i++) {
                // Keep rank 4 (shape [1, C, H, W]) so imageFromINDArray's
                // getDouble(0, c, y, x) reads this element.
                INDArray one = sample.get(NDArrayIndex.interval(i, i + 1),
                        NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
                panel.add(getImage(one));
            }
        }

        frame.revalidate();
        frame.pack();
    }

    /**
     * Renders the tensor via imageFromINDArray and wraps it in a JLabel scaled
     * to 64x64 for the preview grid.
     */
    private static JLabel getImage(INDArray tensor) {
        Image rendered = new ImageIcon(imageFromINDArray(tensor)).getImage();
        Image display = rendered.getScaledInstance(64, 64, Image.SCALE_REPLICATE);
        return new JLabel(new ImageIcon(display));
    }
    /**
     * Renders element 0 of a 4-channel (batch, 4, H, W) tensor -- the "merge2"
     * output (1 mask channel + 3 color channels) -- as an RGB BufferedImage.
     * Channel order assumed here is [alpha, blue, green, red] -- TODO confirm
     * against the MergeVertex ordering in gan().
     *
     * Improvement: dropped the per-pixel System.out.println (4096 lines per
     * 64x64 image) that flooded the console and made rendering unusably slow.
     */
    private static BufferedImage imageFromINDArray(INDArray array) {
        long[] shape = array.shape();
        long height = shape[2];
        long width = shape[3];
        BufferedImage image = new BufferedImage((int) width, (int) height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                double red1 = array.getDouble(0, 3, y, x);
                double green1 = array.getDouble(0, 2, y, x);
                double blue1 = array.getDouble(0, 1, y, x);
                double alpha1 = array.getDouble(0, 0, y, x);
                // Clamp to [0, 255]. Color channels use a *1000 gain -- presumably a
                // brightness hack for small TANH outputs; confirm before changing.
                int alpha = Math.max(Math.min((int) (alpha1 * 255), 255), 0);
                int red = Math.max(Math.min((int) (red1 * 1000), 255), 0);
                int green = Math.max(Math.min((int) (green1 * 1000), 255), 0);
                int blue = Math.max(Math.min((int) (blue1 * 1000), 255), 0);
                // NOTE: TYPE_INT_RGB discards the alpha component on write.
                image.setRGB(x, y, new Color(red, green, blue, alpha).getRGB());
            }
        }
        return image;
    }


    /**
     * Displays loaded source tensors in the shared preview frame, one label per
     * tensor (only element 0 of each batch is rendered by getImageSrc).
     *
     * @param samples tensors of shape (batch, C, H, W)
     */
    private static void visualizeSrc(INDArray[] samples) {
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());
            panel = new JPanel(new GridLayout(samples.length / 3, 1, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }
        panel.removeAll();
        for (INDArray s : samples) {
            panel.add(getImageSrc(s));
        }
        frame.revalidate();
        frame.pack();
    }
    /**
     * Renders a source tensor via imageFromINDArraySrc and wraps it in a JLabel
     * scaled up 8x (to 224x224) for display.
     */
    private static JLabel getImageSrc(INDArray tensor) {
        Image rendered = new ImageIcon(imageFromINDArraySrc(tensor)).getImage();
        Image display = rendered.getScaledInstance(8 * 28, 8 * 28, Image.SCALE_REPLICATE);
        return new JLabel(new ImageIcon(display));
    }
    /**
     * Renders element 0 of a 3-channel (batch, 3, H, W) tensor as an RGB
     * BufferedImage. Channel order assumed BGR (index 2 = red .. index 0 = blue)
     * -- TODO confirm against the image record reader. Values are expected in
     * [0, 1] (see the ImagePreProcessingScaler in main) and clamped after the
     * *255 scale.
     *
     * Improvement: dropped the per-pixel System.out.println debug spam that made
     * rendering unusably slow.
     */
    private static BufferedImage imageFromINDArraySrc(INDArray array) {
        long[] shape = array.shape();
        long height = shape[2];
        long width = shape[3];
        BufferedImage image = new BufferedImage((int) width, (int) height, BufferedImage.TYPE_INT_RGB);
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                double red1 = array.getDouble(0, 2, y, x);
                double green1 = array.getDouble(0, 1, y, x);
                double blue1 = array.getDouble(0, 0, y, x);
                int red = Math.max(Math.min((int) (red1 * 255), 255), 0);
                int green = Math.max(Math.min((int) (green1 * 255), 255), 0);
                int blue = Math.max(Math.min((int) (blue1 * 255), 255), 0);
                image.setRGB(x, y, new Color(red, green, blue, 255).getRGB());
            }
        }
        return image;
    }

}
